import numpy as np
import copy
import torch.nn as nn
import torch
from Network.network_utils import run_optimizer, pytorch_model
from ActualCausal.Train.train_utils import get_done_flags, aggregate_result
from ActualCausal.Utils.run_dataset import get_operation, compute_types
from tianshou.data import Batch

def hot_traces(traces, num_clusters):
    """Replace every trace with a one-hot vector naming its distinct value.

    Distinct traces are numbered in order of first appearance. A trace
    matches an earlier one when the Euclidean norm of their difference is
    below 1e-4.

    Args:
        traces: iterable of numeric arrays (e.g. the rows of a 2-D array).
        num_clusters: length of each one-hot vector, i.e. the maximum number
            of distinct traces expected.

    Returns:
        ``np.ndarray`` of float one-hot rows, one per input trace (shape
        ``(0,)`` when ``traces`` is empty).

    Raises:
        IndexError: if more than ``num_clusters`` distinct traces occur.
    """
    # Rows of the identity are the one-hot codes; the clamp makes a
    # non-positive cluster count give an empty basis, just like range() would.
    basis = np.identity(max(num_clusters, 0))
    representatives = list()  # first occurrence of each distinct trace
    encoded = list()
    for trace in traces:
        cluster = next(
            (k for k, rep in enumerate(representatives) if np.linalg.norm(trace - rep) < 0.0001),
            None,
        )
        if cluster is None:
            # new distinct trace: it takes the next unused cluster id
            cluster = len(representatives)
            representatives.append(copy.deepcopy(trace))
        encoded.append(basis[cluster].copy())
    return np.array(encoded)

def train_binaries(args, params, model, buffer, traces, form="probs", log_batch=None, additional=None, keep_all=False, itr_num=0, intermediate_logger=None):
    """Train the interaction model to predict per-sample binary targets.

    Runs ``args.active.trace_steps`` optimization steps. Each step:

    1. Samples a batch from ``buffer`` and takes the matching rows of ``traces``.
    2. Clips those rows to ``[soft_val, 1 - soft_val]``
       (``soft_val = args.active.trace.soft_val``) and multiplies them by
       ``batch.valid``.
    3. Minimizes binary cross-entropy between the model's ``mask_logits`` and
       these targets.

    Args:
        args: configuration namespace (reads ``args.active.*`` and
            ``args.train.batch_size``).
        params: supplies ``trace_weights``, used as sampling weights when
            ``args.active.trace.use_trace_weights`` is set.
        model: interaction model exposing ``infer`` and ``get_model_optim``.
        buffer: replay buffer; ``sample(batch_size, weights)`` returns
            ``(batch, indices)``.
        traces: target array indexed by the buffer indices returned by
            ``sample``.
        form: output key requested from ``model.infer``. If it contains
            ``"all"``, the ``"all_inter"`` model/optimizer is trained;
            otherwise ``"full_inter"`` is trained.
        log_batch: forwarded to ``model.infer`` (defaults to an empty list).
        additional: forwarded to ``model.infer`` (defaults to an empty list).
        keep_all: forwarded to ``model.infer``. When False, targets are
            filtered by ``result.omit_flags[0]``.
        itr_num: outer iteration number, used to compute the logging step.
        intermediate_logger: optional logger that receives each step's result.

    Returns:
        ``Batch`` aggregating every step's result via ``aggregate_result``
        (``combine_type="cat0"``). It includes ``trace_loss``, ``trace_diff``,
        ``trace_true_diff`` and ``gradients``.
    """
    # None defaults avoid one mutable list being shared across calls; both
    # lists are only forwarded to model.infer.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional
    inter_loss = nn.BCELoss()
    results = Batch()
    # The trained sub-model depends only on ``form``, so choose it once.
    model_name = "all_inter" if "all" in form else "full_inter"
    for i in range(args.active.trace_steps):
        # get the input and target values
        batch, idxes = buffer.sample(args.train.batch_size, params.trace_weights if args.active.trace.use_trace_weights else None)
        trace = traces[idxes].reshape(len(batch), -1)  # in the all mode, this should be the full trace
        # keep targets inside [soft_val, 1 - soft_val] instead of exactly 0/1
        trace = np.clip(trace, args.active.trace.soft_val, 1.0 - args.active.trace.soft_val)
        trace = batch.valid * trace  # zero out invalid values
        # get the network outputs
        # outputs the binary over all instances, in order of names, instance number
        result = model.infer(batch, batch.valid, [form], log_batch=log_batch, additional=additional, keep_all=keep_all)[form]
        done_flags = result.omit_flags

        # compute loss
        trace_target = pytorch_model.wrap(trace, cuda=result.mask_logits.is_cuda)
        # Compare ranks, not an int against a torch.Size (which is always
        # unequal). Add a singleton axis only when the model output has more
        # dimensions than the 2-D targets built above.
        if len(trace_target.shape) != len(result.mask_logits.shape):
            trace_target = trace_target.unsqueeze(1)
        if not keep_all:
            # presumably drops samples the model flagged as omitted — TODO confirm omit_flags[0] semantics
            trace_target = trace_target[done_flags[0]]
        # TODO: requires trace or utrace in all model outputs
        true_trace = pytorch_model.wrap(result.utrace, cuda=result.mask_logits.is_cuda) if "utrace" in result else result.trace

        # mean absolute error over the batch axis (dim 0), against true_trace and against the softened target
        result.trace_true_diff = pytorch_model.unwrap(torch.mean(torch.abs(result.mask_logits - true_trace), dim=0))
        result.trace_diff = pytorch_model.unwrap(torch.mean(torch.abs(result.mask_logits - trace_target), dim=0))
        # nn.BCELoss expects probabilities in [0, 1]; mask_logits is assumed to already be post-sigmoid
        result.trace_loss = inter_loss(result.mask_logits, trace_target)
        grad_variables = [result.inter_input] if args.active.include_gradient else list()
        compute_model, optim = model.get_model_optim([model_name])
        result.gradients = run_optimizer(optim[0], compute_model[0], result.trace_loss, grad_variables=grad_variables)
        results = aggregate_result(results, result, i, combine_type="cat0", cast_numpy=True)
        if intermediate_logger is not None:
            intermediate_logger.log(itr_num * args.active.trace_steps + i, {"binaries": result}, intermediate_name="_binaries")
    return results

def train_trace(args, params, model, buffer, form="probs", log_batch=None, additional=None, keep_all=False, itr_num=0, intermediate_logger=None):
    """Build binary training targets from the buffer and train the interaction model on them.

    Targets start as the buffer's stored ``trace`` field. They are then
    transformed according to ``args.full_inter.selection_train``:

    - ``"separate"``: each distinct trace becomes a one-hot cluster id
      (``hot_traces`` with ``args.interaction_net.cluster.num_clusters``).
    - ``"softened"``: clipped to ``[0.2, 0.8]``.
    - ``"random"``: each entry is independently selected with probability 0.1
      and shifted by one random sign (+1 or -1, drawn once per call); the
      result is clipped to ``[0, 1]``.
    - ``"random_ones"``: each entry is independently raised to 1 with
      probability 0.1 (add, then clip to ``[0, 1]``).
    - ``"proximity"``: replaced by the buffer's ``proximity`` field.
    - ``"gradient"``: 1 where the likelihood gradient from ``get_operation``
      is > 0, and 0 where it is <= 0.

    Any other value leaves the raw traces unchanged. The targets are then
    passed to ``train_binaries``.

    Args:
        args, params, model, buffer, form, log_batch, additional, keep_all,
        itr_num, intermediate_logger: forwarded to ``train_binaries``.
            ``log_batch`` and ``additional`` default to empty lists.

    Returns:
        The result of ``train_binaries``.

    Raises:
        NotImplementedError: if ``args.inter.interaction.subset_training > 0``.
    """
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional
    traces = buffer.sample(0)[0].trace
    mode = args.full_inter.selection_train
    if mode == "separate":
        traces = hot_traces(traces, args.interaction_net.cluster.num_clusters)
    elif mode == "softened":
        SOFT_EPSILON = 0.2
        traces = np.clip(traces, SOFT_EPSILON, 1.0 - SOFT_EPSILON)
    elif mode == "random":
        # np.random.binomial requires n; n=1 gives a 0/1 mask with P(1) = 0.1
        flip_mask = np.random.binomial(n=1, p=0.1, size=traces.shape)
        # NOTE(review): one sign is drawn for the whole array, so a call only
        # adds or only removes ones; pass size=traces.shape if per-entry noise
        # was intended.
        sign = (np.random.randint(0, 2) - 0.5) * 2
        traces = np.clip(traces + flip_mask * sign, 0, 1)
    elif mode == "random_ones":
        traces = np.clip(traces + np.random.binomial(n=1, p=0.1, size=traces.shape), 0, 1)
    elif mode == "proximity":
        traces = buffer.sample(0)[0].proximity
    elif mode == "gradient":
        THRESHOLD = 0
        inp_grad_vals = get_operation(model, buffer, eval_type=compute_types.LIKELIHOOD_GRADIENT)
        traces[inp_grad_vals > THRESHOLD] = 1
        traces[inp_grad_vals <= THRESHOLD] = 0
    if args.inter.interaction.subset_training > 0:
        # Semi-supervised subset training is not wired up here. train_binaries
        # indexes ``traces`` by buffer index, so truncating traces to a subset
        # would misalign targets. The earlier subset code also referenced names
        # never defined in this function.
        raise NotImplementedError(
            "train_trace: args.inter.interaction.subset_training > 0 is not supported"
        )

    # training
    return train_binaries(args, params, model, buffer, traces, form=form, log_batch=log_batch, additional=additional, keep_all=keep_all, itr_num=itr_num, intermediate_logger=intermediate_logger)